from torch import nn
class LSTM(nn.Module):
    """Sequence-to-one model: a stacked LSTM followed by a linear head.

    The LSTM encodes the whole input sequence. Only its output at the final
    time step is passed through dropout and then a linear layer.

    Args:
        input_num: Number of features per time step (LSTM ``input_size``).
        hid_num: LSTM hidden size; also the input size of the linear head.
        layers_num: Number of stacked LSTM layers.
        out_num: Size of the output vector produced by the linear head.
        batch_first: If True, ``forward`` expects ``(batch, time, features)``;
            otherwise ``(time, batch, features)``.
        dropout: Dropout probability applied to the last-step LSTM output
            before the linear head.
    """

    def __init__(
        self,
        input_num,
        hid_num,
        layers_num,
        out_num,
        batch_first=True,
        dropout=0.1,
    ):
        super().__init__()
        self.l1 = nn.LSTM(
            input_size=input_num,
            hidden_size=hid_num,
            num_layers=layers_num,
            batch_first=batch_first,
        )
        self.drop_out = nn.Dropout(dropout)
        self.out = nn.Linear(hid_num, out_num)

    def forward(self, data):
        """Encode ``data`` and map the final time step's output to ``out_num``.

        Args:
            data: 3-D input tensor. Its layout is ``(B, T, D)`` when the model
                was built with ``batch_first=True``, else ``(T, B, D)``.

        Returns:
            Tensor of shape ``(B, out_num)``.
        """
        # No initial state is passed, so nn.LSTM starts from zero h_0 / c_0.
        l_out, _ = self.l1(data)
        # The time axis is dim 1 for batch-first layout and dim 0 otherwise.
        # Read the flag from the LSTM itself, so models pickled before this
        # change keep working.
        if self.l1.batch_first:
            last_step = l_out[:, -1, :]
        else:
            last_step = l_out[-1]
        # Dropout is only needed on the slice that feeds the linear head.
        return self.out(self.drop_out(last_step))
